{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Under the Hood of Encrypted Neural Networks\n",
    "\n",
    "This tutorial is optional, and can be skipped without loss of continuity.\n",
    "\n",
    "In this tutorial, we'll take a look at how CrypTen performs inference with an encrypted neural network on encrypted data. We'll see how the data remains encrypted through all the operations, and yet is able to obtain accurate results after the computation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:module 'torchvision.models.mobilenet' has no attribute 'ConvBNReLU'\n"
     ]
    }
   ],
   "source": [
    "import crypten\n",
    "import torch\n",
    "\n",
    "crypten.init() \n",
    "torch.set_num_threads(1)\n",
    "\n",
    "# Ignore warnings\n",
    "import warnings; \n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Keep track of all created temporary files so that we can clean up at the end\n",
    "temp_files = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Simple Linear Layer\n",
    "We'll start by examining how a single Linear layer works in CrypTen. We'll instantiate a torch Linear layer, convert to CrypTen layer, encrypt it, and step through some toy data with it. As in earlier tutorials, we'll assume Alice has the rank 0 process and Bob has the rank 1 process. We'll also assume Alice has the layer and Bob has the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define ALICE and BOB src values\n",
    "ALICE = 0\n",
    "BOB = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Plaintext Weights:\n",
      "\n",
      " Parameter containing:\n",
      "tensor([[-0.4258, -0.4130, -0.4771, -0.4829],\n",
      "        [-0.0176,  0.4334,  0.4345,  0.1994]], requires_grad=True)\n",
      "\n",
      "Plaintext Bias:\n",
      "\n",
      " Parameter containing:\n",
      "tensor([-0.3670, -0.3700], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Instantiate single Linear layer\n",
    "layer_linear = nn.Linear(4, 2)\n",
    "\n",
    "# The weights and the bias are initialized to small random values\n",
    "print(\"Plaintext Weights:\\n\\n\", layer_linear._parameters['weight'])\n",
    "print(\"\\nPlaintext Bias:\\n\\n\", layer_linear._parameters['bias'])\n",
    "\n",
    "# Save the plaintext layer\n",
    "layer_linear_file = \"/tmp/tutorial5_layer_alice1.pth\"\n",
    "crypten.save(layer_linear, layer_linear_file)\n",
    "temp_files.append(layer_linear_file) \n",
    "\n",
    "# Generate some toy data\n",
    "features = 4\n",
    "examples = 3\n",
    "toy_data = torch.rand(examples, features)\n",
    "\n",
    "# Save the plaintext toy data\n",
    "toy_data_file = \"/tmp/tutorial5_data_bob1.pth\"\n",
    "crypten.save(toy_data, toy_data_file)\n",
    "temp_files.append(toy_data_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights:\n",
      " tensor([[ 4809765439338565026, -8528297308125766579, -9049258375232459840,\n",
      "          2844226043919239965],\n",
      "        [-3908530731123408934,  1774359065878558133,  2533144602464543266,\n",
      "         -3356227231624953279]])\n",
      "Bias:\n",
      " tensor([-7879285751097988616, -1952849712014642234]) \n",
      "\n",
      "Decrypted result:\n",
      " tensor([[-1.0097,  0.1116],\n",
      "        [-1.5431,  0.2374],\n",
      "        [-1.3470,  0.1359]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import crypten.mpc as mpc\n",
    "import crypten.communicator as comm\n",
    "\n",
    "@mpc.run_multiprocess(world_size=2)\n",
    "def forward_single_encrypted_layer():\n",
    "    # Load and encrypt the layer\n",
    "    layer = crypten.load_from_party(layer_linear_file, src=ALICE)\n",
    "    layer_enc = crypten.nn.from_pytorch(layer, dummy_input=torch.empty((1,4)))\n",
    "    layer_enc.encrypt(src=ALICE)\n",
    "    \n",
    "    # Note that layer parameters are encrypted:\n",
    "    crypten.print(\"Weights:\\n\", layer_enc.weight.share)\n",
    "    crypten.print(\"Bias:\\n\", layer_enc.bias.share, \"\\n\")\n",
    "    \n",
    "    # Load and encrypt data\n",
    "    data_enc = crypten.load_from_party(toy_data_file, src=BOB)\n",
    "    \n",
    "    # Apply the encrypted layer (linear transformation):\n",
    "    result_enc = layer_enc.forward(data_enc)\n",
    "    \n",
    "    # Decrypt the result:\n",
    "    result = result_enc.get_plain_text()\n",
    "    \n",
    "    # Examine the result\n",
    "    crypten.print(\"Decrypted result:\\n\", result)\n",
    "        \n",
    "forward_single_encrypted_layer()"
   ]
  },
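  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The large integers printed above are one party's share of each parameter, not the parameter itself. As a rough illustration of the idea, here is a minimal sketch of two-party additive secret sharing over fixed-point integers in plain Python. It is not CrypTen's implementation: the helper names `share` and `reconstruct` are hypothetical, and the 64-bit ring and 16-bit fixed-point scale are assumptions chosen to mirror the magnitudes seen in the output. Each share on its own is a uniformly random number; only the sum of the two carries information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secrets\n",
    "\n",
    "RING = 2 ** 64   # assumption: shares live in the integers modulo 2^64\n",
    "SCALE = 2 ** 16  # assumption: 16 fractional bits of fixed-point precision\n",
    "\n",
    "\n",
    "def to_signed(v):\n",
    "    # Map a value in [0, 2^64) to the signed 64-bit range\n",
    "    return v - RING if v >= RING // 2 else v\n",
    "\n",
    "\n",
    "def share(x):\n",
    "    # Encode the float as a fixed-point integer, then split it into two shares\n",
    "    encoded = round(x * SCALE) % RING\n",
    "    share0 = secrets.randbelow(RING)    # uniformly random mask\n",
    "    share1 = (encoded - share0) % RING  # chosen so the two shares sum to the encoding\n",
    "    return to_signed(share0), to_signed(share1)\n",
    "\n",
    "\n",
    "def reconstruct(share0, share1):\n",
    "    # Adding the shares modulo 2^64 cancels the mask and recovers the encoding\n",
    "    return to_signed((share0 + share1) % RING) / SCALE\n",
    "\n",
    "\n",
    "s0, s1 = share(-0.4258)\n",
    "print(\"Share held by party 0:\", s0)\n",
    "print(\"Share held by party 1:\", s1)\n",
    "print(\"Reconstructed value:\", reconstruct(s0, s1))"
   ]
  },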
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the application of the encrypted linear layer on the encrypted data produces an encrypted result, which we can then decrypt to get the values in plaintext.\n",
    "\n",
    "Let's look at a second linear transformation, to give a flavor of how accuracy is preserved even when the data and the layer are encrypted. We'll look at a uniform scaling transformation, in which all tensor elements are multiplied by the same scalar factor. Again, we'll assume Alice has the layer and the rank 0 process, and Bob has the data and the rank 1 process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a linear layer with random weights\n",
    "layer_scale = nn.Linear(3, 3)\n",
    "\n",
    "# Construct a uniform scaling matrix: we'll scale by factor 5\n",
    "factor = 5\n",
    "layer_scale._parameters['weight'] = torch.eye(3)*factor\n",
    "layer_scale._parameters['bias'] = torch.zeros_like(layer_scale._parameters['bias'])\n",
    "\n",
    "# Save the plaintext layer\n",
    "layer_scale_file = \"/tmp/tutorial5_layer_alice2.pth\"\n",
    "crypten.save(layer_scale, layer_scale_file)\n",
    "temp_files.append(layer_scale_file)\n",
    "\n",
    "# Construct some toy data\n",
    "features = 3\n",
    "examples = 2\n",
    "toy_data = torch.ones(examples, features)\n",
    "\n",
    "# Save the plaintext toy data\n",
    "toy_data_file = \"/tmp/tutorial5_data_bob2.pth\"\n",
    "crypten.save(toy_data, toy_data_file)\n",
    "temp_files.append(toy_data_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights:\n",
      " tensor([[ 1048173932012528739, -8485385189168951104, -4402857909290748548],\n",
      "        [ 4553819400594206674, -7570443770611756974, -5679305187004366181],\n",
      "        [-1176794114084926816,   871515787882802395,  6055415794932662708]])\n",
      "Bias:\n",
      "\n",
      " tensor([-2824390750419204213,  3169342280882583467,  2952982236928138320])\n",
      "Plaintext result:\n",
      " tensor([[5., 5., 5.],\n",
      "        [5., 5., 5.]])\n"
     ]
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def forward_scaling_layer():\n",
    "    rank = comm.get().get_rank()\n",
    "    \n",
    "    # Load and encrypt the layer\n",
    "    layer = crypten.load_from_party(layer_scale_file, src=ALICE)\n",
    "    layer_enc = crypten.nn.from_pytorch(layer, dummy_input=torch.empty((1,3)))\n",
    "    layer_enc.encrypt(src=ALICE)\n",
    "    \n",
    "    # Load and encrypt data\n",
    "    data_enc = crypten.load_from_party(toy_data_file, src=BOB)   \n",
    "    \n",
    "    # Note that layer parameters are (still) encrypted:\n",
    "    crypten.print(\"Weights:\\n\", layer_enc.weight)\n",
    "    crypten.print(\"Bias:\\n\\n\", layer_enc.bias)\n",
    "\n",
    "    # Apply the encrypted scaling transformation\n",
    "    result_enc = layer_enc.forward(data_enc)\n",
    "\n",
    "    # Decrypt the result:\n",
    "    result = result_enc.get_plain_text()\n",
    "    crypten.print(\"Plaintext result:\\n\", (result))\n",
    "        \n",
    "z = forward_scaling_layer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resulting plaintext tensor is correctly scaled, even though we applied the encrypted transformation on the encrypted input! "
   ]
  },
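  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the same scaling can be computed in plain PyTorch and compared with the decrypted result. This is only a sketch: it assumes that a Linear layer computes `x @ W.T + b`, uses illustrative variable names (`plain_weight`, `plain_bias`, `plain_data`), and does not involve CrypTen at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Rebuild the scaling matrix, zero bias, and all-ones toy data in plaintext\n",
    "plain_weight = torch.eye(3) * 5\n",
    "plain_bias = torch.zeros(3)\n",
    "plain_data = torch.ones(2, 3)\n",
    "\n",
    "# A Linear layer computes x @ W^T + b\n",
    "expected = plain_data @ plain_weight.t() + plain_bias\n",
    "print(\"Expected plaintext result:\\n\", expected)"
   ]
  },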
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-layer Neural Networks\n",
    "Let's now look at how the encrypted input moves through an encrypted multi-layer neural network. \n",
    "\n",
    "For ease of explanation, we'll first step through a network with only two linear layers and ReLU activations. Again, we'll assume Alice has a network and Bob has some data, and they wish to run encrypted inference. \n",
    "\n",
    "To simulate this, we'll once again generate some toy data and train Alice's network on it. Then we'll encrypt Alice's network, Bob's data, and step through every layer in the network with the encrypted data. Through this, we'll see how the computations get applied although the network and the data are encrypted.\n",
    "\n",
    "### Setup\n",
    "As in Tutorial 3, we will first generate 1000 ground truth samples using 50 features and a randomly generated hyperplane to separate positive and negative examples. We will then modify the labels so that they are all non-negative. Finally, we will split the data so that the first 900 samples belong to Alice and the last 100 samples belong to Bob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "features = 50\n",
    "examples = 1000\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "torch.manual_seed(1)\n",
    "\n",
    "# Generate toy data and separating hyperplane\n",
    "data = torch.randn(examples, features)\n",
    "w_true = torch.randn(1, features)\n",
    "b_true = torch.randn(1)\n",
    "labels = w_true.matmul(data.t()).add(b_true).sign()\n",
    "\n",
    "# Change labels to non-negative values\n",
    "labels_nn = torch.where(labels==-1, torch.zeros(labels.size()), labels)\n",
    "labels_nn = labels_nn.squeeze().long()\n",
    "\n",
    "# Split data into Alice's and Bob's portions:\n",
    "data_alice, labels_alice = data[:900], labels_nn[:900]\n",
    "data_bob, labels_bob = data[900:], labels_nn[900:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Alice's network\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class AliceNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AliceNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(50, 20)\n",
    "        self.fc2 = nn.Linear(20, 2)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99 Loss: 0.24704290926456451\n",
      "Epoch 199 Loss: 0.08965438604354858\n",
      "Epoch 299 Loss: 0.05166155472397804\n",
      "Epoch 399 Loss: 0.03510778397321701\n",
      "Epoch 499 Loss: 0.026072457432746887\n"
     ]
    }
   ],
   "source": [
    "# Train and save Alice's network\n",
    "model = AliceNet()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "for i in range(500):  \n",
    "    #forward pass: compute prediction\n",
    "    output = model(data_alice)\n",
    "    \n",
    "    #compute and print loss\n",
    "    loss = criterion(output, labels_alice)\n",
    "    if i % 100 == 99:\n",
    "        print(\"Epoch\", i, \"Loss:\", loss.item())\n",
    "    \n",
    "    #zero gradients for learnable parameters\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    #backward pass: compute gradient with respect to model parameters\n",
    "    loss.backward()\n",
    "    \n",
    "    #update model parameters\n",
    "    optimizer.step()\n",
    "\n",
    "sample_trained_model_file = '/tmp/tutorial5_alice_model.pth'\n",
    "torch.save(model, sample_trained_model_file)\n",
    "temp_files.append(sample_trained_model_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stepping through a Multi-layer Network\n",
    "\n",
    "Let's now look at what happens when we load the network Alice's has trained and encrypt it. First, we'll look at how the network structure changes when we convert it from a PyTorch network to CrypTen network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name: 5 \tModule: Linear encrypted module\n",
      "Name: 6 \tModule: ReLU encrypted module\n",
      "Name: output \tModule: Linear encrypted module\n"
     ]
    }
   ],
   "source": [
    "# Load the trained network to Alice\n",
    "model_plaintext = crypten.load(sample_trained_model_file, model_class=AliceNet, src=ALICE)\n",
    "\n",
    "# Convert the trained network to CrypTen network \n",
    "private_model = crypten.nn.from_pytorch(model_plaintext, dummy_input=torch.empty((1, 50)))\n",
    "# Encrypt the network\n",
    "private_model.encrypt(src=ALICE)\n",
    "\n",
    "# Examine the structure of the encrypted CrypTen network\n",
    "for name, curr_module in private_model._modules.items():\n",
    "    print(\"Name:\", name, \"\\tModule:\", curr_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the encrypted network has 3 modules, named '5', '6' and 'output', denoting the first Linear layer, the ReLU activation, and the second Linear layer respectively. These modules are encrypted just as the layers in the previous section were. \n",
    "\n",
    "Now let's encrypt Bob's data, and step it through each encrypted module. For readability, we will use only 3 examples from Bob's data to illustrate the inference. Note how Bob's data remains encrypted after each individual layer's computation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-processing: Select only the first three examples in Bob's data for readability\n",
    "data = data_bob[:3]\n",
    "sample_data_bob_file = '/tmp/tutorial5_data_bob3.pth'\n",
    "torch.save(data, sample_data_bob_file)\n",
    "temp_files.append(sample_data_bob_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rank: 0\n",
      "\tFirst Linear Layer: Output Encrypted: True\n",
      "Rank: 1\n",
      "\tFirst Linear Layer: Output Encrypted: True\n",
      "Rank: 0\n",
      "\tShares after First Linear Layer:tensor([[ 1294104436359772009,   175627571159595069,  6642614375759992943,\n",
      "          6081499958679808446,  7458846727696033441,  3147092037857368600,\n",
      "         -3046033583744813218,  3656714442306812623, -1444801573761934624,\n",
      "         -7296071273089953465,  8612281116080325317,  3767966956893906701,\n",
      "          1168460730848498997,  6391421666103853539,  2248247350390000515,\n",
      "          3213300948547734905, -1008946278573850451,  3023665117027050103,\n",
      "          6452227059867486924,  5122606758130927186],\n",
      "        [ 1294098018270385457,   175572314767882774,  6642679547617989340,\n",
      "          6081490500342724210,  7458890389841627996,  3146958713752628383,\n",
      "         -3046031468676736309,  3656827531990650240, -1444926189414126081,\n",
      "         -7295908481581646714,  8612296818825007185,  3767896044585741179,\n",
      "          1168492541025886293,  6391421981103357070,  2248478059279583032,\n",
      "          3213072974367093577, -1008986944565666746,  3023726324856312448,\n",
      "          6452034988922444717,  5122807803705415836],\n",
      "        [ 1294230301877800966,   175637085163985991,  6642453136769952687,\n",
      "          6081713335905482773,  7458898869236110557,  3147122860701412511,\n",
      "         -3046035755646826395,  3656960256252269441, -1444781735284474806,\n",
      "         -7296033774970670019,  8612497350047865696,  3767842152663519175,\n",
      "          1168515149450951187,  6391351342609178513,  2248345649544202536,\n",
      "          3213240247148552395, -1009008179631440065,  3023561862620483044,\n",
      "          6452162933179580076,  5122675530857862209]])\n",
      "Rank: 1\n",
      "\tShares after First Linear Layer:tensor([[-1294104436359773334,  -175627571159599635, -6642614375759997860,\n",
      "         -6081499958679789837, -7458846727696016852, -3147092037857415289,\n",
      "          3046033583744823616, -3656714442306790117,  1444801573761958273,\n",
      "          7296071273089957293, -8612281116080315108, -3767966956893894700,\n",
      "         -1168460730848401264, -6391421666103901272, -2248247350389922683,\n",
      "         -3213300948547715500,  1008946278573893179, -3023665117027054537,\n",
      "         -6452227059867552203, -5122606758130943834],\n",
      "        [-1294098018270317155,  -175572314767931912, -6642679547617914147,\n",
      "         -6081490500342719925, -7458890389841616593, -3146958713752681542,\n",
      "          3046031468676771016, -3656827531990574613,  1444926189414140031,\n",
      "          7295908481581578942, -8612296818825047980, -3767896044585702015,\n",
      "         -1168492541025970234, -6391421981103363019, -2248478059279514797,\n",
      "         -3213072974367128391,  1008986944565682293, -3023726324856344564,\n",
      "         -6452034988922471079, -5122807803705298654],\n",
      "        [-1294230301877803129,  -175637085163946559, -6642453136769800546,\n",
      "         -6081713335905393451, -7458898869236133111, -3147122860701417556,\n",
      "          3046035755646800937, -3656960256252185447,  1444781735284458812,\n",
      "          7296033774970745398, -8612497350047781551, -3767842152663546097,\n",
      "         -1168515149451053385, -6391351342609222152, -2248345649544305325,\n",
      "         -3213240247148642863,  1009008179631340188, -3023561862620490211,\n",
      "         -6452162933179517464, -5122675530857881919]])\n",
      "Rank: 0\n",
      "\tReLU:\n",
      " Output Encrypted: True\n",
      "Rank: 1\n",
      "\tReLU:\n",
      " Output Encrypted: True\n",
      "Rank: 0\n",
      "\tShares after ReLU: tensor([[ -69418112779619, -104819378117147,  -53821829423311,   -4702416953602,\n",
      "           81692254766900,   26354342345479,   19910955357038,   99396979972767,\n",
      "          -15535301645722,  116395626260581,   57574295823065,  -31804488560163,\n",
      "           -6157843889001,  127929889663321,   -9682178820609,  -34987980601125,\n",
      "          -66205526212571,  -33563851263958,  114769799056332,  110462795743251],\n",
      "        [ -71607331589659,  -67637186468049,   93445435998571, -107752143382447,\n",
      "          -75294651671773,  -25877267189444,  -45188604229407,   23555656170411,\n",
      "         -139958107653092,  100501491982742,   53781245538104,   17887297553747,\n",
      "          119527271279592,  -20996799538528,  -56976576213408,  138826578103652,\n",
      "           89363069875799,   43441519497618,  -10905259880145,   47079313829219],\n",
      "        [-129245172456225,  119322793913743,    1233960650848,  -23553427467120,\n",
      "           15852207530575,  -47833042200528,   20662867237334, -111298739886237,\n",
      "           83819509957564, -125551541691829,  -54474249003337,  -44663410212654,\n",
      "           31257650077361,  -34896836197447,  -55792830223424,  101913793004469,\n",
      "           -3443767350948, -140066874002600,  -65121025755447,  -74467758976736]])\n",
      "\n",
      "Rank: 1\n",
      "\tShares after ReLU: tensor([[  69418112779619,  104819378117147,   53821829423311,    4702416972211,\n",
      "          -81692254750311,  -26354342345479,  -19910955346640,  -99396979950261,\n",
      "           15535301669371, -116395626256753,  -57574295812856,   31804488572164,\n",
      "            6157843986734, -127929889663321,    9682178898441,   34987980620530,\n",
      "           66205526255299,   33563851263958, -114769799056332, -110462795743251],\n",
      "        [  71607331657961,   67637186468049,  -93445435923378,  107752143386732,\n",
      "           75294651683176,   25877267189444,   45188604264114,  -23555656094784,\n",
      "          139958107667042, -100501491982742,  -53781245538104,  -17887297514583,\n",
      "         -119527271279592,   20996799538528,   56976576281643, -138826578103652,\n",
      "          -89363069860252,  -43441519497618,   10905259880145,  -47079313712037],\n",
      "        [ 129245172456225, -119322793874311,   -1233960498707,   23553427556442,\n",
      "          -15852207530575,   47833042200528,  -20662867237334,  111298739970231,\n",
      "          -83819509957564,  125551541767208,   54474249087482,   44663410212654,\n",
      "          -31257650077361,   34896836197447,   55792830223424, -101913793004469,\n",
      "            3443767350948,  140066874002600,   65121025818059,   74467758976736]])\n",
      "\n",
      "Rank: 0 Second Linear layer:\n",
      " Output Encrypted: True\n",
      "\n",
      "Rank: 1 Second Linear layer:\n",
      " Output Encrypted: True\n",
      "\n",
      "Rank: 0 Shares after Second Linear layer:tensor([[ 1227074762242099399, -7183963528918030511],\n",
      "        [ 1227101465106941139, -7183869735820509603],\n",
      "        [ 1226927389649283233, -7183833265029392183]])\n",
      "\n",
      "Rank: 1 Shares after Second Linear layer:tensor([[-1227074762242272097,  7183963528918200002],\n",
      "        [-1227101465106801116,  7183869735820423652],\n",
      "        [-1226927389648997099,  7183833265029168808]])\n",
      "\n",
      "Decrypted output:\n",
      " Output Encrypted: False\n",
      "Tensors:\n",
      " tensor([[-2.6352,  2.5862],\n",
      "        [ 2.1366, -1.3115],\n",
      "        [ 4.3661, -3.4084]])\n"
     ]
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def step_through_two_layers():    \n",
    "    rank = comm.get().get_rank()\n",
    "\n",
    "    # Load and encrypt the network\n",
    "    model = crypten.load_from_party(sample_trained_model_file, model_class=AliceNet, src=ALICE)\n",
    "    private_model = crypten.nn.from_pytorch(model, dummy_input=torch.empty((1, 50)))\n",
    "    private_model.encrypt(src=ALICE)\n",
    "\n",
    "    # Load and encrypt the data\n",
    "    data_enc = crypten.load(sample_data_bob_file, src=BOB)\n",
    "\n",
    "    # Forward through the first layer\n",
    "    out_enc = private_model._modules['5'].forward(data_enc)\n",
    "    encrypted = crypten.is_encrypted_tensor(out_enc)\n",
    "    crypten.print(f\"Rank: {rank}\\n\\tFirst Linear Layer: Output Encrypted: {encrypted}\", in_order=True)\n",
    "    crypten.print(f\"Rank: {rank}\\n\\tShares after First Linear Layer:{out_enc.share}\", in_order=True)\n",
    "\n",
    "    # Apply ReLU activation\n",
    "    out_enc = private_model._modules['6'].forward(out_enc)\n",
    "    encrypted = crypten.is_encrypted_tensor(out_enc)\n",
    "    crypten.print(f\"Rank: {rank}\\n\\tReLU:\\n Output Encrypted: {encrypted}\", in_order=True)\n",
    "    crypten.print(f\"Rank: {rank}\\n\\tShares after ReLU: {out_enc.share}\\n\", in_order=True)\n",
    "\n",
    "    # Forward through the second Linear layer\n",
    "    out_enc = private_model._modules['output'].forward(out_enc)\n",
    "    encrypted = crypten.is_encrypted_tensor(out_enc)\n",
    "    crypten.print(f\"Rank: {rank} Second Linear layer:\\n Output Encrypted: {encrypted}\\n\", in_order=True) \n",
    "    crypten.print(f\"Rank: {rank} Shares after Second Linear layer:{out_enc.share}\\n\", in_order=True)\n",
    "\n",
    "    # Decrypt the output\n",
    "    out_dec = out_enc.get_plain_text()\n",
    "    \n",
    "    # Since both parties have same decrypted results, only print the rank 0 output\n",
    "    crypten.print(\"Decrypted output:\\n Output Encrypted:\", crypten.is_encrypted_tensor(out_dec))\n",
    "    crypten.print(\"Tensors:\\n\", out_dec)\n",
    "    \n",
    "z = step_through_two_layers()"
   ]
  },
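  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the printed shares to the decrypted output, the sketch below adds the two parties' shares of the first output row (copied from the output above) and rescales the sum. The scale of 2^16 is an assumption about the fixed-point encoding, not something the notebook prints. In general the sum would also be reduced modulo 2^64; for these particular values that reduction changes nothing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shares of the first output row, copied from the printed output above\n",
    "share_rank0 = [1227074762242099399, -7183963528918030511]\n",
    "share_rank1 = [-1227074762242272097, 7183963528918200002]\n",
    "\n",
    "SCALE = 2 ** 16  # assumption: 16 fractional bits of fixed-point precision\n",
    "\n",
    "# Python integers do not overflow, so the sums are exact\n",
    "decoded = [(a + b) / SCALE for a, b in zip(share_rank0, share_rank1)]\n",
    "\n",
    "# Matches the first row of the decrypted tensor: [-2.6352, 2.5862]\n",
    "print([round(v, 4) for v in decoded])"
   ]
  },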
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, we emphasize that the output of each layer is an encrypted tensor. Only after the final call to `get_plain_text` do we get the plaintext tensor.\n",
    "\n",
    "### From PyTorch to CrypTen: Structural Changes in Network Architecture \n",
    "\n",
    "We have used a simple two-layer network in the above example, but the same ideas apply to more complex networks and operations. However, in more complex networks, there may not always be a one-to-one mapping between the PyTorch layers and the CrypTen layers. This is because we use PyTorch's onnx implementation to convert PyTorch models to CrypTen models. \n",
    "As an example, we'll take a typical network used to classify digits in MNIST data, and look at what happens to its structure we convert it to a CrypTen module. (As we only wish to illustrate the structural changes in layers, we will not train this network on data; we will just use it with its randomly initialized weights). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name: 46 \tModule: Conv2d encrypted module\n",
      "Name: 26 \tModule: ReLU encrypted module\n",
      "Name: 27 \tModule: _ConstantPad encrypted module\n",
      "Name: 28 \tModule: AvgPool2d encrypted module\n",
      "Name: 49 \tModule: Conv2d encrypted module\n",
      "Name: 31 \tModule: ReLU encrypted module\n",
      "Name: 32 \tModule: _ConstantPad encrypted module\n",
      "Name: 33 \tModule: AvgPool2d encrypted module\n",
      "Name: 34 \tModule: Shape encrypted module\n",
      "Name: 36 \tModule: Gather encrypted module\n",
      "Name: 37 \tModule: Constant encrypted module\n",
      "Name: 38 \tModule: Unsqueeze encrypted module\n",
      "Name: 39 \tModule: Unsqueeze encrypted module\n",
      "Name: 40 \tModule: Concat encrypted module\n",
      "Name: 41 \tModule: Reshape encrypted module\n",
      "Name: 42 \tModule: Linear encrypted module\n",
      "Name: 43 \tModule: _BatchNorm encrypted module\n",
      "Name: 44 \tModule: ReLU encrypted module\n",
      "Name: output \tModule: Linear encrypted module\n"
     ]
    }
   ],
   "source": [
    "# Define Alice's network\n",
    "class AliceNet2(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AliceNet2, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)\n",
    "        self.conv2 = nn.Conv2d(16, 16, kernel_size=5, padding=0)\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 100)\n",
    "        self.fc2 = nn.Linear(100, 10)\n",
    "        self.batchnorm1 = nn.BatchNorm2d(16)\n",
    "        self.batchnorm2 = nn.BatchNorm2d(16)\n",
    "        self.batchnorm3 = nn.BatchNorm1d(100)\n",
    " \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.batchnorm1(out)\n",
    "        out = F.relu(out)\n",
    "        out = F.avg_pool2d(out, 2)\n",
    "        out = self.conv2(out)\n",
    "        out = self.batchnorm2(out)\n",
    "        out = F.relu(out)\n",
    "        out = F.avg_pool2d(out, 2)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.batchnorm3(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "    \n",
    "model = AliceNet2()\n",
    "\n",
    "# Let's encrypt the complex network. \n",
    "# Create dummy input of the correct input shape for the model\n",
    "dummy_input = torch.empty((1, 1, 28, 28))\n",
    "\n",
    "# Encrypt the network\n",
    "private_model = crypten.nn.from_pytorch(model, dummy_input)\n",
    "private_model.encrypt(src=ALICE)\n",
    "\n",
    "# Examine the structure of the encrypted network\n",
    "for name, curr_module in private_model._modules.items():\n",
    "    print(\"Name:\", name, \"\\tModule:\", curr_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the CrypTen network has split some the layers in the PyTorch module into several CrypTen modules. Each PyTorch operation may correspond to one or more operations in CrypTen. However, during the conversion, these are sometimes split due to limitations intorduced by onnx.\n",
    "\n",
    "Before exiting this tutorial, please clean up the files generated using the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "for fn in temp_files:\n",
    "    if os.path.exists(fn): os.remove(fn)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
